#pragma once

/*
* HTTP server implementation for anet.
* Copyright (C) 2021-2022 gavingqf(gavingqf@126.com)
*
*    Distributed under the Boost Software License, Version 1.0.
*    (See accompanying file LICENSE_1_0.txt or copy at
*    http://www.boost.org/LICENSE_1_0.txt)
*/

/*
 * Only the GET and POST methods are supported for now;
 * requests using any other method are rejected.
 */

#include <atomic>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "event_loop.hpp"
#include "server.hpp"
#include "log.h"
#include "http_info.hpp"
#include "asio/detail/noncopyable.hpp"
#include "anet.hpp"

 // if not windows, just define stricmp and strnicmp
#ifndef _WIN32
#define stricmp strcasecmp
#define strnicmp strncasecmp
#endif

namespace anet {
	namespace http {
		// Parsed HTTP request received from a client. Despite the "Res" name, this
		// struct holds the incoming *request*; the session builds its response from it.
		struct HttpRes {
			// request method token from the request line, e.g. "GET" or "POST".
			std::string method{ "" };

			// request path; for GET the query string is split off into `content`.
			std::string path{ "" };

			// protocol version token from the request line, e.g. "HTTP/1.1".
			std::string version{ "" };

			// Content-Type header value; gHttpContentTypeJson when the header is absent.
			std::string contentType{ "" };

			// POST: the request body; GET: the query string after the first '?' (may be empty).
			std::string content{ "" };

			// Connection header value; gHttpConnectionClose when the header is absent.
			std::string connectionMode{ "" };

			// remote peer address and port (set by the owner, not by parse()).
			std::string remoteIP{ "0" };
			unsigned short remotePort{ 0 };

			const std::string &getMethod() const {
				return this->method;
			}

			const std::string &getClientIP() const {
				return this->remoteIP;
			}
			unsigned short getClientPort() const {
				return this->remotePort;
			}

			const std::string &getContent() const {
				return this->content;
			}

			const std::string &getVersion() const {
				return this->version;
			}

			const std::string &getPath() const {
				return this->path;
			}

			const std::string &getContentType() const {
				return this->contentType;
			}

			// Parses one complete GET or POST request into this object.
			// Returns false for any other method, for a POST without the blank line
			// that ends the headers, or when the request line has fewer than 3 tokens.
			// Header names are compared case-insensitively (RFC 7230, section 3.2).
			bool parse(const std::string &content) {
				const bool isPost = startsWithNoCase(content, gHttpPostMethod);
				const bool isGet = !isPost && startsWithNoCase(content, gHttpGetMethod);
				if (!isPost && !isGet) {
					return false;
				}

				// The head (request line + headers) ends at the first blank line.
				// A POST must have one because its body follows it; a malformed request
				// is network input, so it is reported with `false` rather than an assert.
				const auto headEnd = content.find(gHttpSpaceLine);
				const bool hasBlankLine = (headEnd != gHttpStrFindInvalid);
				if (isPost && !hasBlankLine) {
					return false;
				}

				// Only the head is split into lines, so body text can never be
				// mistaken for a header.
				const std::string head = hasBlankLine ? content.substr(0, headEnd) : content;
				std::vector<std::string> lines;
				strSplits(head, gHttpCRLF, lines);
				if (lines.empty()) {
					return false;
				}

				this->contentType = headerValue(lines, gHttpContentTypeLine, gHttpContentTypeJson);
				this->connectionMode = headerValue(lines, gHttpConnectionLine, gHttpConnectionClose);

				// request line: "<method> <target> <version>"
				std::vector<std::string> requestLine;
				strSplits(lines[0], " ", requestLine);
				if (requestLine.size() < 3) {
					return false;
				}

				this->method = std::move(requestLine[0]);
				this->version = std::move(requestLine[2]);
				if (isPost) {
					this->path = std::move(requestLine[1]);
					// substr (not a char* copy) keeps the body binary-safe: embedded
					// NUL bytes no longer truncate it.
					this->content = content.substr(headEnd + gHttpSpaceLineLen);
				} else {
					// The query is everything after the FIRST '?' (RFC 3986, section 3.4);
					// a later '?' belongs to the query.
					const std::string &target = requestLine[1];
					const auto query = target.find('?');
					this->path = target.substr(0, query); // whole target when there is no '?'
					this->content = (query == std::string::npos) ? std::string() : target.substr(query + 1);
				}
				return true;
			}

		private:
			// True when `text` begins with `prefix`, ignoring ASCII case.
			static bool startsWithNoCase(const std::string &text, const std::string &prefix) {
				return text.size() >= prefix.size() &&
					0 == strnicmp(text.c_str(), prefix.c_str(), prefix.size());
			}

			// Value of the first header line that starts (case-insensitively) with `key`.
			// lines[0] is skipped because it is the request line. Returns `fallback`
			// when no header matches.
			static std::string headerValue(const std::vector<std::string> &lines,
				const std::string &key, const std::string &fallback) {
				for (size_t i = 1; i < lines.size(); ++i) {
					if (startsWithNoCase(lines[i], key)) {
						return lines[i].substr(key.size());
					}
				}
				return fallback;
			}
		};

		struct HttpRes;
		typedef std::pair<httpStatus, std::string> httpResPair;

		// http res callback typedef.
		typedef std::function<httpResPair(const HttpRes&)> httpResCallback;
		typedef std::function<httpResPair(const HttpRes&, std::string)> httpHandleDealer;

		// http handler typedef.
		typedef std::map<std::string, httpResCallback> httpResHandlerMap;

		// http server's session.
		// Handles one connection: every complete request delivered by the codec is
		// parsed, dispatched through the server-supplied handler, and answered on the
		// same connection.
		class CHttpServerSession final : public anet::tcp::ISession {
		public:
			// httpResDealer is invoked as httpResDealer(request, request.path).
			explicit CHttpServerSession(const httpHandleDealer& httpResDealer) :
				m_resCall(httpResDealer) {
			}
			virtual ~CHttpServerSession() = default;

		public:
			// Keeps the connection, sets linger to 0 and marks the session open.
			virtual void onConnected(anet::tcp::connSharePtr conn) override {
				m_conn = conn;
				m_conn->setLinger(0);
				m_close = false;
			}

			// Called with one complete request (as framed by CHttpServerCodec).
			virtual void onRecv(const char *data, int len) override {
				HttpRes httpRes;
				httpRes.remoteIP = m_conn->getRemoteAddr();
				httpRes.remotePort = m_conn->getRemotePort();

				// A malformed request is untrusted input, not a programming error.
				// Drop the connection instead of asserting and leaving the client
				// waiting for a response that never comes.
				if (!httpRes.parse(std::string(data, data + len))) {
					m_conn->close();
					return;
				}

				const auto resPair = m_resCall(httpRes, httpRes.path);

				// return result back.
				this->sendHttpRes(httpRes, resPair);
			}

			virtual void onTerminate() override {
				m_close = true;
			}

			virtual void release() override {
				delete this;
			}

			bool isClose() const {
				return m_close;
			}

		private:
			// Builds and sends the response for httpReq.
			// The Connection header now reflects what the server actually does: it
			// closes after sending when the request's Connection value starts with
			// gHttpConnectionClose (case-insensitive), and otherwise keeps the
			// connection open. The header must agree with that (RFC 7230, section 6.6).
			void sendHttpRes(const HttpRes &httpReq, const httpResPair &retPair) {
				const std::string &body = retPair.second;
				const bool closeAfterSend = 0 == strnicmp(httpReq.connectionMode.c_str(),
					gHttpConnectionClose, strlen(gHttpConnectionClose));

				std::string res;
				// header overhead plus the body, so large bodies do not force reallocation.
				res.reserve(m_sendBuffSize + body.size());

				// status line: "<version> <code> <description>"
				res += httpReq.version;
				res += ' ';
				res += std::to_string(static_cast<int>(retPair.first));
				res += ' ';
				res += getStatusDesc(retPair.first);
				res.append(gHttpCRLF, gHttpCRLFLen);

				// Content type: echo the request's, or fall back to url-encoded when empty.
				res.append(gHttpContentTypeLine, strlen(gHttpContentTypeLine));
				if (httpReq.contentType.empty()) {
					res.append(gHttpContentTypeUrlEncoded);
				} else {
					res.append(httpReq.contentType);
				}
				res.append(gHttpCRLF, gHttpCRLFLen);

				// Connection header matching the action taken below.
				res.append(gHttpConnectionLine, strlen(gHttpConnectionLine));
				if (closeAfterSend) {
					res.append(gHttpConnectionClose, strlen(gHttpConnectionClose));
				} else {
					res.append("keep-alive");
				}
				res.append(gHttpCRLF, gHttpCRLFLen);

				// Content length, then the blank line that ends the headers.
				res += gHttpContentLenLine;
				res += std::to_string(body.size());
				res.append(gHttpSpaceLine, gHttpSpaceLineLen);

				// content.
				res += body;

				// send all contents.
				this->send(res);

				if (closeAfterSend) {
					m_conn->close();
				}
			}

		protected:
			// Sends raw bytes unless the session has already been terminated.
			void send(const std::string &res) {
				if (isClose()) return;
				m_conn->send(res);
			}

		private:
			anet::tcp::connSharePtr m_conn{ nullptr };
			httpHandleDealer m_resCall;
			std::atomic_bool m_close{ true };
			static const int m_sendBuffSize{ 1024 };
		};

		// http server's session factory.
		class CHttpSessionFactory final : public anet::tcp::ISessionFactory {
		public:
			explicit CHttpSessionFactory(const httpHandleDealer& httpDealer) :
				m_resDealer(httpDealer) {
			}
			virtual ~CHttpSessionFactory() = default;
			CHttpSessionFactory(const CHttpSessionFactory&) = delete;
			CHttpSessionFactory& operator= (const CHttpSessionFactory&) = delete;

		public:
			virtual anet::tcp::ISession *createSession() override {
				return new CHttpServerSession(m_resDealer);
			}

		private:
			httpHandleDealer m_resDealer;
		};

		// http server codec.
		class CHttpServerCodec final : public anet::tcp::ICodec {
		public:
			CHttpServerCodec() = default;
			virtual ~CHttpServerCodec() = default;

		public:
			virtual int parsePacket(const char *data, int len) override {
				std::string strData(data, len);
				auto pos = strData.find(gHttpSpaceLine);
				if (pos == gHttpStrFindInvalid) {
					return  anet::tcp::retError;
				}

				// check post or get method.
				if (0 == strnicmp(data, gHttpPostMethod, strlen(gHttpPostMethod))) {
					auto lenPos = strData.find(gHttpContentLenLine);
					if (lenPos == gHttpStrFindInvalid) {
						return anet::tcp::retError; // just close it.
					}

					// content len flag string.
					std::string conLenString(&data[lenPos + gHttpContentLenLineLen]);
					auto contentLen = std::atoi(conLenString.c_str());
					if (len >= int(pos) + int(gHttpSpaceLineLen) + contentLen) {
						return int(pos + gHttpSpaceLineLen + contentLen);
					} else {
						return anet::tcp::retNotComplete;
					}
				} else if (0 == strnicmp(data, gHttpGetMethod, strlen(gHttpGetMethod))) {
					return int(pos + gHttpSpaceLineLen);
				} else {
					return anet::tcp::retError; // only support get or post, else others just close it.
				}
			}
		};


		// http server.
		class CHttpServer final : asio::noncopyable {
		public:
			explicit CHttpServer(int size = 0) : m_loop(size), m_server(m_loop), 
				m_factory(nullptr), m_codec(nullptr) {
			}
			virtual ~CHttpServer() { 
				release(); 
			}

		public:
			// start listens at addr:port, where addr is ip address, and "0" is "0.0.0.0"
			bool start(const std::string &addr, unsigned short port) {
				auto methodHandler = std::bind(&CHttpServer::httpReqDo, this, std::placeholders::_1, std::placeholders::_2);
				m_factory = new CHttpSessionFactory(methodHandler);
				m_server.setSessionFactory(m_factory);
				m_codec = new CHttpServerCodec();
				m_server.setPacketParser(m_codec);
				return m_server.start(addr, port);
			}
			// start listen with port.
			bool start(unsigned short port) {
				return start("0", port);
			}

			// register post handler.
			bool post(std::string method, httpResCallback callback) {
				if (m_postHandlers.find(method) != m_postHandlers.end()) {
					return false;
				} else {
					m_postHandlers[method] = std::move(callback);
					return true;
				}
			}

			// register get handler. 
			bool get(std::string method, httpResCallback callback) {
				if (m_getHandlers.find(method) != m_getHandlers.end()) {
					return false;
				} else {
					m_getHandlers[method] = std::move(callback);
					return true;
				}
			}

			void release() {
				if (m_factory != nullptr) {
					delete m_factory;
					m_factory = nullptr;
				}
				if (m_codec != nullptr) {
					delete m_codec;
					m_codec = nullptr;
				}
			}

		protected:
			httpResPair httpReqDo(const HttpRes &httpReq, const std::string &method) {
				if (0 == stricmp(httpReq.getMethod().c_str(), gHttpGetMethod)) {
					return httpGetReqDo(httpReq, method);
				} else if (0 == stricmp(httpReq.getMethod().c_str(), gHttpPostMethod)) {
					return httpPostReqDo(httpReq, method);
				} else {
					return { httpStatus::BAD_REQUEST,"" };
				}
			}

			// http post req callback. return handle result({errorId, string})
			httpResPair httpPostReqDo(const HttpRes &httpReq, const std::string &method) {
				auto it = m_postHandlers.find(method);
				if (it != m_postHandlers.end()) {
					return it->second(httpReq);
				} else {
					return { httpStatus::BAD_REQUEST,"" };
				}
			}

			// http get req handler.return handle result({errorId, string})
			httpResPair httpGetReqDo(const HttpRes &httpReq, const std::string &method) {
				auto it = m_getHandlers.find(method);
				if (it != m_getHandlers.end()) {
					return it->second(httpReq);
				} else {
					return { httpStatus::BAD_REQUEST,"" };
				}
			}

		protected:
			// event loop.
			anet::tcp::CEventLoop m_loop;

			// tcp server.
			anet::tcp::CServer m_server;

			// post handlers map.
			httpResHandlerMap m_postHandlers;

			// get handlers map.
			httpResHandlerMap m_getHandlers;

			// session factory and codec.
			anet::tcp::ISessionFactory* m_factory{ nullptr };
			anet::tcp::ICodec* m_codec{ nullptr };
		};
	}
}
